// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_PARAMS_H_
#define THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_PARAMS_H_

#include <ac_std_float.h>
#include <mc_connections.h>

// Configuration message with a fixed marshalled bit width. Every field is
// handed to the marshaller in declaration order by Marshall(); `width` must
// stay in sync with the set of fields marshalled there.
struct Params {
  int M0;
  int P1;
  int N1;
  int M1;
  int P2;
  int INPUT_OFFSET;
  int WEIGHT_OFFSET;
  int OUTPUT_OFFSET;
  bool SOFTMAX;
  ac::bfloat16 SCALE;
  bool TRANSPOSE;
  int VECTOR_OFFSET;
  bool VEC_OP;
  bool VEC_SUB;
  bool VEC_SQUARE;
  bool VEC_REDUCE;
  bool CONST_SCALE;
  int VEC_SCALE_OFFSET;
  int VEC_SUB_OFFSET;
  bool RELU;  // TODO: support GELU
  // 2 levels of memory hierarchy
  int loops[2][6];
  int inputXLoopIndex[2];
  int inputYLoopIndex[2];
  int reductionLoopIndex[2];
  int weightLoopIndex[2];
  int fxIndex;
  int fyIndex;
  int weightReuseIndex[2];
  int STRIDE;

  // support for multi-head operations
  bool CONCAT_HEAD;
  bool SPLIT_HEAD;
  int HEAD_SZ_LG2;

  // width is the number of bits in the struct, counting every int as 32 bits
  // and every bool as 1 bit:
  //    11 scalar ints: M0, P1, N1, M1, P2, INPUT_OFFSET, WEIGHT_OFFSET,
  //                    OUTPUT_OFFSET, VECTOR_OFFSET, VEC_SCALE_OFFSET,
  //                    VEC_SUB_OFFSET
  // +  12 ints:        loops[2][6]
  // +  13 ints:        inputXLoopIndex[2], inputYLoopIndex[2],
  //                    reductionLoopIndex[2], weightLoopIndex[2], fxIndex,
  //                    fyIndex, weightReuseIndex[2], STRIDE
  // +   1 int:         HEAD_SZ_LG2
  // +  10 bools:       SOFTMAX, TRANSPOSE, VEC_OP, VEC_SUB, VEC_SQUARE,
  //                    VEC_REDUCE, CONST_SCALE, RELU, CONCAT_HEAD, SPLIT_HEAD
  // +   1 bf16:        SCALE
  // Update this whenever a field is added to or removed from Marshall().
  static const unsigned int width =
      (11 + 12 + 13 + 1) * 32 + 10 * 1 + ac::bfloat16::width;

  // Passes every field to the marshaller `m` through its operator&, in
  // declaration order (multi-element arrays element by element, row-major).
  template <unsigned int Size>
  void Marshall(Marshaller<Size>& m) {
    m& M0;
    m& P1;
    m& N1;
    m& M1;
    m& P2;
    m& INPUT_OFFSET;
    m& WEIGHT_OFFSET;
    m& OUTPUT_OFFSET;
    m& SOFTMAX;
    m& SCALE;
    m& TRANSPOSE;
    m& VECTOR_OFFSET;
    m& VEC_OP;
    m& VEC_SUB;
    m& VEC_SQUARE;
    m& VEC_REDUCE;
    m& CONST_SCALE;
    m& VEC_SCALE_OFFSET;
    m& VEC_SUB_OFFSET;
    m& RELU;
    for (int i = 0; i < 2; i++) {
      for (int j = 0; j < 6; j++) {
        m& loops[i][j];
      }
    }
    for (int i = 0; i < 2; i++) {
      m& inputXLoopIndex[i];
    }
    for (int i = 0; i < 2; i++) {
      m& inputYLoopIndex[i];
    }
    for (int i = 0; i < 2; i++) {
      m& reductionLoopIndex[i];
    }
    for (int i = 0; i < 2; i++) {
      m& weightLoopIndex[i];
    }
    m& fxIndex;
    m& fyIndex;
    for (int i = 0; i < 2; i++) {
      m& weightReuseIndex[i];
    }
    m& STRIDE;
    m& CONCAT_HEAD;
    m& SPLIT_HEAD;
    m& HEAD_SZ_LG2;
  }

  // Registers the fields M0 through RELU with the trace file `tf`, each named
  // NAME + ".<FIELD>". SCALE is skipped (its call is left disabled below), and
  // the fields declared after RELU (loops, the loop-index arrays, fxIndex,
  // fyIndex, STRIDE and the multi-head fields) are not traced.
  inline friend void sc_trace(sc_trace_file* tf, const Params& params,
                              const std::string& NAME) {
    sc_trace(tf, params.M0, NAME + ".M0");
    sc_trace(tf, params.P1, NAME + ".P1");
    sc_trace(tf, params.N1, NAME + ".N1");
    sc_trace(tf, params.M1, NAME + ".M1");
    sc_trace(tf, params.P2, NAME + ".P2");
    sc_trace(tf, params.INPUT_OFFSET, NAME + ".INPUT_OFFSET");
    sc_trace(tf, params.WEIGHT_OFFSET, NAME + ".WEIGHT_OFFSET");
    sc_trace(tf, params.OUTPUT_OFFSET, NAME + ".OUTPUT_OFFSET");
    sc_trace(tf, params.SOFTMAX, NAME + ".SOFTMAX");
    // NOTE(review): presumably disabled because no sc_trace overload is
    // available for ac::bfloat16 -- confirm before re-enabling.
    // sc_trace(tf, params.SCALE, NAME + ".SCALE");
    sc_trace(tf, params.TRANSPOSE, NAME + ".TRANSPOSE");
    sc_trace(tf, params.VECTOR_OFFSET, NAME + ".VECTOR_OFFSET");
    sc_trace(tf, params.VEC_OP, NAME + ".VEC_OP");
    sc_trace(tf, params.VEC_SUB, NAME + ".VEC_SUB");
    sc_trace(tf, params.VEC_SQUARE, NAME + ".VEC_SQUARE");
    sc_trace(tf, params.VEC_REDUCE, NAME + ".VEC_REDUCE");
    sc_trace(tf, params.CONST_SCALE, NAME + ".CONST_SCALE");
    sc_trace(tf, params.VEC_SCALE_OFFSET, NAME + ".VEC_SCALE_OFFSET");
    sc_trace(tf, params.VEC_SUB_OFFSET, NAME + ".VEC_SUB_OFFSET");
    sc_trace(tf, params.RELU, NAME + ".RELU");
  }

  // Writes the fields M0 through RELU (including SCALE) to `os`, each followed
  // by a single space, and returns `os`. Fields declared after RELU are not
  // printed. The stream type is spelled std::ostream so this header does not
  // depend on a using-declaration supplied by another header.
  inline friend std::ostream& operator<<(std::ostream& os,
                                         const Params& params) {
    os << params.M0 << " ";
    os << params.P1 << " ";
    os << params.N1 << " ";
    os << params.M1 << " ";
    os << params.P2 << " ";
    os << params.INPUT_OFFSET << " ";
    os << params.WEIGHT_OFFSET << " ";
    os << params.OUTPUT_OFFSET << " ";
    os << params.SOFTMAX << " ";
    os << params.SCALE << " ";
    os << params.TRANSPOSE << " ";
    os << params.VECTOR_OFFSET << " ";
    os << params.VEC_OP << " ";
    os << params.VEC_SUB << " ";
    os << params.VEC_SQUARE << " ";
    os << params.VEC_REDUCE << " ";
    os << params.CONST_SCALE << " ";
    os << params.VEC_SCALE_OFFSET << " ";
    os << params.VEC_SUB_OFFSET << " ";
    os << params.RELU << " ";

    return os;
  }
};

// Configuration message with a fixed marshalled bit width. Every field is
// handed to the marshaller in declaration order by Marshall(); `width` must
// stay in sync with the set of fields marshalled there.
struct VectorParams {
  int VECTOR_OFFSET;
  int OUTPUT_OFFSET;

  int loops[2][3];
  int xLoopIndex[2];
  int yLoopIndex[2];
  int kLoopIndex[2];

  bool useMatrixProcessorOutput;
  bool SCALE;
  ac::bfloat16 SCALE_FACTOR;
  bool RELU;
  bool SOFTMAX;
  bool LAYER_NORM;

  bool SPLIT_HEAD;
  int HEAD_SZ_LG2;

  // width is the number of bits in the struct, counting every int as 32 bits
  // and every bool as 1 bit:
  //     2 ints:  VECTOR_OFFSET, OUTPUT_OFFSET
  // +  12 ints:  loops[2][3], xLoopIndex[2], yLoopIndex[2], kLoopIndex[2]
  // +   5 bools: useMatrixProcessorOutput, SCALE, RELU, SOFTMAX, LAYER_NORM
  // +   1 bf16:  SCALE_FACTOR
  // +   1 bool:  SPLIT_HEAD
  // +   1 int:   HEAD_SZ_LG2
  // Update this whenever a field is added to or removed from Marshall().
  static const unsigned int width =
      2 * 32 + 12 * 32 + 5 * 1 + ac::bfloat16::width + 1 + 32;

  // Passes every field to the marshaller `m` through its operator&, in
  // declaration order (multi-element arrays element by element, row-major).
  template <unsigned int Size>
  void Marshall(Marshaller<Size>& m) {
    m& VECTOR_OFFSET;
    m& OUTPUT_OFFSET;

    for (int i = 0; i < 2; i++) {
      for (int j = 0; j < 3; j++) {
        m& loops[i][j];
      }
    }
    for (int i = 0; i < 2; i++) {
      m& xLoopIndex[i];
    }
    for (int i = 0; i < 2; i++) {
      m& yLoopIndex[i];
    }
    for (int i = 0; i < 2; i++) {
      m& kLoopIndex[i];
    }

    m& useMatrixProcessorOutput;
    m& SCALE;
    m& SCALE_FACTOR;
    m& RELU;
    m& SOFTMAX;
    m& LAYER_NORM;

    m& SPLIT_HEAD;
    m& HEAD_SZ_LG2;
  }

  // Placeholder: currently registers nothing with the trace file, so no
  // VectorParams field is traced.
  inline friend void sc_trace(sc_trace_file* tf, const VectorParams& params,
                              const std::string& NAME) {
    // TODO
  }

  // Placeholder: currently writes nothing and returns `os` unchanged. The
  // stream type is spelled std::ostream so this header does not depend on a
  // using-declaration supplied by another header.
  inline friend std::ostream& operator<<(std::ostream& os,
                                         const VectorParams& params) {
    // TODO
    return os;
  }
};

#endif  // THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_PARAMS_H_
